from __future__ import annotations

import torch

from argparse import Namespace

class Aggregator:
    """Base class for rules that combine per-client messages into one result.

    Subclasses must override ``__call__``; the base implementation only raises.
    """

    # Identifiers of clients treated as Byzantine. The value lives on the base
    # class ``Aggregator`` (see ``__init__``), so every subclass that does not
    # shadow this attribute -- including code without access to ``self`` --
    # reads the same object. Each new aggregator instance overwrites it.
    byz_client_idxs = set()

    def __init__(self, args: Namespace, byz_client_idxs: set[str]) -> None:
        """Store the run configuration and publish the Byzantine client set.

        Args:
            args: Run configuration; kept as ``self.args`` for subclasses.
            byz_client_idxs: Identifiers of Byzantine clients (annotated as
                ``str`` -- TODO confirm against callers). The set object itself,
                not a copy, is stored, so later mutations by the caller stay
                visible through ``Aggregator.byz_client_idxs``.
        """
        self.args = args
        # Deliberately assigned on the base class rather than on ``self`` or
        # ``type(self)`` so all aggregator subclasses share one value.
        Aggregator.byz_client_idxs = byz_client_idxs

    @torch.no_grad()
    def __call__(self, client_messages: dict, knowledge: dict) -> dict[str, torch.Tensor]:
        """Aggregate ``client_messages`` into a single result.

        Must be overridden by subclasses. Note that the ``torch.no_grad()``
        decorator on this base method does not carry over to an overriding
        method; an override needs its own decorator if it wants it.

        Args:
            client_messages: Mapping of client index to that client's message
                (see ``get_updates`` for the ``'update'`` entry).
            knowledge: Additional information for the aggregation rule; its
                schema is defined by subclasses -- TODO confirm.

        Returns:
            Per the annotation, a mapping of names to aggregated tensors.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError('specification of aggregator is needed')

    @staticmethod
    def get_updates(client_messages: dict) -> dict[str, dict[str, torch.Tensor]]:
        """Extract the ``'update'`` entry from every client's message.

        Args:
            client_messages: Mapping of client index to that client's message
                dict; every message must contain an ``'update'`` key.

        Returns:
            A new dict mapping each client index to ``message['update']``, in
            the iteration order of ``client_messages``. The update objects
            themselves are not copied.

        Raises:
            KeyError: if any message lacks an ``'update'`` key.
        """
        return {client_idx: message['update'] for client_idx, message in client_messages.items()}